import os
import pandas as pd
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm
from models.FinetuneVTmodels import MIL_VT_FineTune
from models.MIL_VT import *




def optimize_embeddings_pgd(model, cur_input, target_emb, learning_rate, epsilon, l2_dist_threshold, cosine_sim_threshold, mil_emb=False, max_iterations=None):
    """
    Adjusts the input so that its embedding approaches a target embedding,
    using gradient descent with an element-wise clamp on every step.

    Each iteration takes one gradient step on the MSE between ``target_emb``
    and the current embedding, clamps that step to ``[-epsilon, epsilon]``,
    and re-evaluates the embedding of the updated input. Iteration stops once
    the squared L2 distance is below ``l2_dist_threshold`` AND the cosine
    similarity is above ``cosine_sim_threshold`` (or ``max_iterations`` is hit).

    The loop condition and the ``.item()`` call on the cosine similarity both
    require a single-element similarity tensor, i.e. a batch of one sample.

    Args:
        model (torch.nn.Module): Model for generating embeddings. Must expose
            ``forward_features(x)`` returning a pair whose first element is the
            embedding (used when ``mil_emb`` is False) and be callable as
            ``model(x)`` returning a pair whose second element is the
            embedding (used when ``mil_emb`` is True).
        cur_input (torch.Tensor): Input tensor to be optimized. Not modified.
        target_emb (torch.Tensor): Target embedding to match.
        learning_rate (float): Step size for gradient descent.
        epsilon (float): Maximum absolute change per element in a single step.
        l2_dist_threshold (float): Iterate while squared L2 distance >= this.
        cosine_sim_threshold (float): Iterate while cosine similarity <= this.
        mil_emb (bool): Selects which embedding of the model is matched
            (see ``model`` above).
        max_iterations (int | None): Optional cap on the number of steps.
            ``None`` (default) keeps iterating until both thresholds are met,
            which never terminates if they are unreachable.
    Returns:
        tuple: ``(optimized_input, l1_dist_arr, cosine_sim_arr, loss_arr,
        squared_l2_distance_arr, iteration_count, cosine_sim,
        squared_l2_distance)`` where
            - optimized_input (torch.Tensor): Detached optimized input.
            - l1_dist_arr (list): L1 distance to the original input per step.
            - cosine_sim_arr (list): Cosine similarity per step.
            - loss_arr (list): MSE loss (before the step) per step.
            - squared_l2_distance_arr (list): Squared L2 distance per step.
            - iteration_count (int): Number of steps performed.
            - cosine_sim: Last cosine similarity (a tensor; the int ``0`` if
              no step was performed).
            - squared_l2_distance (float): Last squared L2 distance
              (``inf`` if no step was performed).
    """
    # NOTE(review): epsilon bounds each individual step, not the accumulated
    # perturbation relative to the original input. If a bound on the total
    # change was intended, the clamp should be taken against `org_input`.
    org_input = cur_input.clone()

    l1_dist_arr = []
    cosine_sim_arr = []
    loss_arr = []
    squared_l2_distance_arr = []

    iteration_count = 0
    cosine_sim = 0
    squared_l2_distance = float('inf')

    def _embed(x):
        # Pick the embedding selected by `mil_emb`.
        if mil_emb:
            _, emb = model(x)
        else:
            emb, _ = model.forward_features(x)
        return emb

    while squared_l2_distance >= l2_dist_threshold or cosine_sim <= cosine_sim_threshold:
        if max_iterations is not None and iteration_count >= max_iterations:
            break

        # Fresh leaf tensor so the gradient is taken w.r.t. this step's input only.
        cur_input = cur_input.clone().detach().requires_grad_(True)

        cur_emb = _embed(cur_input)
        loss = F.mse_loss(target_emb, cur_emb)
        loss_arr.append(loss.item())

        # Differentiate w.r.t. the input only: avoids retaining the graph and
        # avoids accumulating .grad on the model parameters at every step.
        (grad,) = torch.autograd.grad(loss, cur_input)

        with torch.no_grad():
            updated_input = cur_input - learning_rate * grad
            projected_input = cur_input + torch.clamp(updated_input - cur_input, -epsilon, epsilon)

            updated_emb = _embed(projected_input)

            squared_l2_distance = torch.sum((target_emb - updated_emb) ** 2).item()
            updated_l1_dist = torch.sum(torch.abs(projected_input - org_input)).item()
            cosine_sim = F.cosine_similarity(target_emb, updated_emb)

        l1_dist_arr.append(updated_l1_dist)
        cosine_sim_arr.append(cosine_sim.detach().cpu().item())
        squared_l2_distance_arr.append(squared_l2_distance)

        cur_input = projected_input
        iteration_count += 1

    return cur_input.detach(), l1_dist_arr, cosine_sim_arr, loss_arr, squared_l2_distance_arr, iteration_count, cosine_sim, squared_l2_distance



def inverse_normalize():
    """
    Builds the transform that undoes ``Normalize(mean=0.5, std=0.5)`` on
    three channels.

    Normalization computes ``(x - mean) / std``; applying a second Normalize
    with ``mean' = -mean / std`` and ``std' = 1 / std`` maps the result back
    to ``x``.

    Returns:
        transforms.Normalize: Transformation to apply inverse normalization.
    """
    channel_means = [0.5, 0.5, 0.5]
    channel_stds = [0.5, 0.5, 0.5]

    inverse_means = []
    inverse_stds = []
    for channel_mean, channel_std in zip(channel_means, channel_stds):
        inverse_means.append(-channel_mean / channel_std)
    for channel_std in channel_stds:
        inverse_stds.append(1 / channel_std)

    return transforms.Normalize(mean=inverse_means, std=inverse_stds)


def run_pgd_attack(lr_val, ep_val):
    """
    Runs the embedding-matching attack on a sample of image pairs and prints
    how often the model's prediction on the optimized image equals the input
    label and how often it equals the target label.

    Relies on the module-level ``model`` and ``device`` objects and on
    ``optimize_embeddings_pgd`` defined in this file. Image pairs are read
    from the CSV columns ``input_id_code``, ``target_id_code``,
    ``input_diagnosis`` and ``target_diagnosis``; pairs whose PNG files are
    missing are skipped.

    Args:
        lr_val (float): Learning rate passed to the optimizer.
        ep_val (float): Epsilon (per-step clamp) passed to the optimizer.

    Returns:
        None: The function prints the results, and the modified images are saved in the specified output directory.
    """
    root_path = '/path/to/images/'
    csv_path = '/path/to/csv/'

    test_df = pd.read_csv(csv_path)
    # Sample at most 200 rows; a smaller CSV is used in full instead of raising.
    test_df = test_df.sample(n=min(200, len(test_df)), random_state=42)

    correct_input_predictions = 0
    correct_target_predictions = 0
    num_samples = 0
    cos_sim = 0.95
    l2_dis_thres = 16

    transform = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    denormalize = inverse_normalize()

    for idx, row in tqdm(test_df.iterrows(), total=len(test_df)):
        input_image_name = row['input_id_code']
        target_image_name = row['target_id_code']
        input_label = int(row['input_diagnosis'])
        target_label = int(row['target_diagnosis'])

        input_image_path = os.path.join(root_path, input_image_name + ".png")
        target_image_path = os.path.join(root_path, target_image_name + ".png")

        if not (os.path.exists(input_image_path) and os.path.exists(target_image_path)):
            print(f"{input_image_path} or {target_image_path} doesn't exist")
            continue

        # Force three channels (the Normalize above is 3-channel) and close
        # the file handles as soon as the tensors are built.
        with Image.open(input_image_path) as image:
            input_image = transform(image.convert("RGB")).unsqueeze(0).to(device)

        with Image.open(target_image_path) as image2:
            target_image = transform(image2.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            target_vit_embedding, _ = model.forward_features(target_image)

        # optimize_embeddings_pgd returns eight values; only the optimized
        # image is needed here, the per-step diagnostics are discarded.
        optimized_image, *_ = optimize_embeddings_pgd(
            model, input_image, target_vit_embedding, learning_rate=lr_val, epsilon=ep_val,
            l2_dist_threshold=l2_dis_thres, cosine_sim_threshold=cos_sim)

        # Undo the normalization only for the image written to disk.
        optimized_image_inv = denormalize(optimized_image)

        with torch.no_grad():
            output, _ = model(optimized_image)

        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

        test_df.at[idx, 'predicted_label'] = predicted_class

        if predicted_class == input_label:
            correct_input_predictions += 1

        if predicted_class == target_label:
            correct_target_predictions += 1

        num_samples += 1

        modified_image_filename = f"input_{input_image_name}->target_{target_image_name}.png"
        save_image(optimized_image_inv, f'/directory/to/save/{modified_image_filename}')

    print(f"lr_{lr_val}, epsilon_{ep_val}")

    if num_samples == 0:
        # Every pair was skipped; there is nothing to average over.
        print("No image pairs were processed; accuracies are undefined.")
        return

    acc_in = correct_input_predictions / num_samples
    print(f'accuracy_in:{acc_in * 100:.2f}%')

    acc_target = correct_target_predictions / num_samples
    print(f'accuracy_target:{acc_target * 100:.2f}%')



# Build the fine-tuned MIL_VT model, restore its weights, and switch it to
# inference mode on the selected device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MIL_VT_FineTune()

checkpoint_path = '/path/to/saved/weight/file'
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# The checkpoint is deserialized on the CPU; the model is moved to `device` below.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model.to(device)

# Hyper-parameter grid for the attack.
lr_list = [0.9, 0.09]
epsilon_list = [0.1, 0.02]

# Run the attack once for every (learning rate, epsilon) combination,
# iterating over epsilon fastest.
for lr_val, ep_val in [(lr, ep) for lr in lr_list for ep in epsilon_list]:
    run_pgd_attack(lr_val, ep_val)
